{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plotting patch histograms\n",
    "Note: the cells below will only run once you've run `patches.py` and `segmenter.py` on the experiment directory,\n",
    "as they pull the patch data from the respective experiment directories (see `scripts/04_eval_patches_gen_models.sh` for an example)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "from utils import rfutil, imutil, show\n",
    "import os\n",
    "from collections import Counter\n",
    "import matplotlib as mpl\n",
    "from PIL import Image\n",
    "import cv2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _read_cluster_counts(filename, counter, cluster_order=None):\n",
    "    \"\"\"Accumulate per-cluster counts from a text file into `counter`.\n",
    "\n",
    "    Every non-blank line is split on single spaces: field 1 (up to its first\n",
    "    comma) is the cluster label and field 2 is the integer count. When\n",
    "    `cluster_order` is given, the integer before the first ':' of the line is\n",
    "    stored in it under the cluster label.\n",
    "    \"\"\"\n",
    "    with open(filename) as f:\n",
    "        for line in f:\n",
    "            if not line.strip():\n",
    "                continue  # tolerate blank / trailing empty lines\n",
    "            fields = line.split(' ')\n",
    "            cluster = fields[1].split(',')[0]\n",
    "            if cluster_order is not None:\n",
    "                cluster_order[cluster] = int(line.split(':')[0])\n",
    "            counter[cluster] += int(fields[2])\n",
    "\n",
    "\n",
    "def _show_cluster_examples(cluster_dir, index, kind, num_patches, mult):\n",
    "    \"\"\"Queue the patch grid and the outlined top image of one cluster via `show.a`.\n",
    "\n",
    "    `kind` ('fake' or 'real') is only used to caption the two displayed items.\n",
    "    \"\"\"\n",
    "    # np.load keeps the .npz archive open; the context manager closes it again\n",
    "    with np.load(os.path.join(cluster_dir, 'cluster_%d.npz' % index)) as cluster_info:\n",
    "        which_model_netD = cluster_info['which_model_netD']\n",
    "        finesize = cluster_info['finesize']\n",
    "        cluster_patches = cluster_info['patch']\n",
    "        cluster_pos = cluster_info['pos']\n",
    "    rfs = rfutil.find_rf_patches(which_model_netD, finesize)\n",
    "    with open(os.path.join(cluster_dir, 'cluster_%d.txt' % index)) as f:\n",
    "        cluster_files = [line.strip() for line in f]\n",
    "\n",
    "    # last `num_patches` entries, last one first; x*0.5+0.5 rescales the values\n",
    "    # (presumably from [-1, 1] to [0, 1] -- TODO confirm against patches.py)\n",
    "    patches = (cluster_patches[-num_patches:][::-1] * 0.5) + 0.5\n",
    "    grid = imutil.imgrid(np.uint8(patches * 255), pad=0, cols=num_patches//2)\n",
    "    patch_grid = Image.fromarray(grid)\n",
    "    w, h = patch_grid.size\n",
    "    show.a(['%s patch' % kind, patch_grid.resize((int(mult*w), int(mult*h)), Image.LANCZOS)])\n",
    "\n",
    "    # image paths in cluster_%d.txt are resolved relative to the parent directory\n",
    "    top_im = Image.open(os.path.join('..', cluster_files[-1])).resize(\n",
    "        (finesize, finesize), Image.LANCZOS)\n",
    "    top_pos = cluster_pos[-1]\n",
    "    top_im_border = np.array(top_im)\n",
    "    # outline the region that rfutil.find_rf_patches reports for the top position\n",
    "    slice_y, slice_x = rfs[top_pos[0], top_pos[1]]\n",
    "    cv2.rectangle(top_im_border, (slice_x.start, slice_y.start),\n",
    "                  (slice_x.stop, slice_y.stop), color=[255, 255, 0], thickness=3)\n",
    "    show.a(['%s image' % kind, Image.fromarray(top_im_border).resize((150, 150), Image.LANCZOS)])\n",
    "\n",
    "\n",
    "def plot_patches(path, num_patches=8, mult=1, outdir='plots', show_legend=True, ylim=None):\n",
    "    \"\"\"Plot top-patch vs. random-patch cluster counts and show example patches.\n",
    "\n",
    "    Reads `baseline.txt` and `counts.txt` from the `fakes_easiest_clusters` and\n",
    "    `reals_easiest_clusters` subdirectories of `path`, sums the per-cluster\n",
    "    counts over both, and saves a bar chart to `<outdir>/histogram_<outname>.pdf`.\n",
    "    For the three most frequent clusters it then displays, for fakes and for\n",
    "    reals, the last `num_patches` stored patches and the last listed image with\n",
    "    the region returned by `rfutil.find_rf_patches` outlined.\n",
    "\n",
    "    Args:\n",
    "        path: clusters directory; its third-from-last '/'-separated component\n",
    "            is used as `<outname>` in the saved file name.\n",
    "        num_patches: number of patches shown per cluster (num_patches//2 columns).\n",
    "        mult: scale factor applied to the displayed patch grid.\n",
    "        outdir: directory for the saved histogram (created if missing).\n",
    "        show_legend: whether to draw the legend.\n",
    "        ylim: optional y-axis limits passed to `ax.set_ylim`.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the patch counts do not add up to the baseline total.\n",
    "    \"\"\"\n",
    "    # strip a trailing '/' so the same component is picked with or without it\n",
    "    outname = path.rstrip('/').split('/')[-3]\n",
    "    fakes_dir = os.path.join(path, 'fakes_easiest_clusters')\n",
    "    reals_dir = os.path.join(path, 'reals_easiest_clusters')\n",
    "\n",
    "    # baseline - fakes\n",
    "    baseline = Counter()\n",
    "    _read_cluster_counts(os.path.join(fakes_dir, 'baseline.txt'), baseline)\n",
    "    print(baseline)\n",
    "    total = sum(baseline.values())\n",
    "    print(total)\n",
    "\n",
    "    # patch counts - fakes\n",
    "    counts = Counter()\n",
    "    cluster_order_fakes = dict()\n",
    "    _read_cluster_counts(os.path.join(fakes_dir, 'counts.txt'), counts, cluster_order_fakes)\n",
    "    print(counts)\n",
    "    assert sum(counts.values()) == total\n",
    "\n",
    "    # baseline - reals (added on top of the fake counts)\n",
    "    _read_cluster_counts(os.path.join(reals_dir, 'baseline.txt'), baseline)\n",
    "    print(baseline)\n",
    "    total = sum(baseline.values())\n",
    "\n",
    "    # patch counts - reals (added on top of the fake counts)\n",
    "    cluster_order_reals = dict()\n",
    "    _read_cluster_counts(os.path.join(reals_dir, 'counts.txt'), counts, cluster_order_reals)\n",
    "    print(counts)\n",
    "    assert sum(counts.values()) == total\n",
    "    print('total patches: %d' % total)\n",
    "\n",
    "    mpl.rcParams.update({'font.size': 18})\n",
    "    # The palette must be set BEFORE the axes exist: an Axes captures the color\n",
    "    # cycle when it is created, so setting it afterwards only affects later\n",
    "    # figures (the first call in a fresh kernel would use the default colors).\n",
    "    sns.set_palette('muted')\n",
    "    f, ax = plt.subplots(1, 1, figsize=(8,4))\n",
    "    sort_labels = sorted(counts, key=counts.get)[::-1]\n",
    "    # one tick every 2 units; the two bars of a cluster sit on either side of it\n",
    "    centers = np.arange(1, len(sort_labels)*2+1, 2)\n",
    "    ax.bar(centers-0.4, [counts[l] for l in sort_labels], label='top patch')\n",
    "    ax.set_xticks(centers)\n",
    "    ax.set_xticklabels(sort_labels, rotation='vertical')\n",
    "    ax.bar(centers+0.4, [baseline[l] for l in sort_labels], label='random patch', alpha=0.5)\n",
    "    ax.set_ylabel('count')\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(ylim)\n",
    "    if show_legend:\n",
    "        ax.legend()\n",
    "    ax.grid(alpha=0.3)\n",
    "    os.makedirs(outdir, exist_ok=True)\n",
    "    f.savefig(os.path.join(outdir, 'histogram_%s.pdf' % outname), bbox_inches='tight')\n",
    "\n",
    "    # example patches for the three most frequent clusters, fakes then reals\n",
    "    sources = (('fake', fakes_dir, cluster_order_fakes),\n",
    "               ('real', reals_dir, cluster_order_reals))\n",
    "    for label in sort_labels[:3]:\n",
    "        for kind, cluster_dir, cluster_order in sources:\n",
    "            if label not in cluster_order:\n",
    "                # the label only occurs on the other side; nothing to show here\n",
    "                print('no %s cluster for label %s; skipping' % (kind, label))\n",
    "                continue\n",
    "            _show_cluster_examples(cluster_dir, cluster_order[label], kind, num_patches, mult)\n",
    "        show.flush()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fully generative models\n",
    "Uncomment one of the following lines to visualize patch histograms for that specific experiment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Select the experiment to visualize: `prefix` is its results directory and\n",
    "# `mult` is the scale factor that plot_patches applies to the displayed patch grid.\n",
    "# prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/celebahq-pgan-pretrained/'; mult=1\n",
    "# prefix = '../results/gp1-gan-winversion_seed0_xception_block3_constant_p10/test/epoch_bestval/celebahq-sgan-pretrained/'; mult=0.5\n",
    "# prefix = '../results/gp1d-gan-samplesonly_seed0_xception_block1_constant_p50/test/epoch_bestval/celebahq-glow-pretrained'; mult=2\n",
    "# prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/celeba-gmm'; mult=1\n",
    "# prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/ffhq-pgan'; mult=1\n",
    "prefix = '../results/gp1-gan-winversion_seed0_xception_block3_constant_p10/test/epoch_bestval/ffhq-sgan2/'\n",
    "mult = 0.5\n",
    "\n",
    "plot_patches(os.path.join(prefix, 'patches_top10000/clusters'), mult=mult)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Faceforensics\n",
    "You can also do something similar with the faceforensics datasets, but you first need to preprocess the frames\n",
    "(following `scripts/00_data_processing_faceforensics_aligned_frames.sh`) and then run the patch experiment\n",
    "pipeline (e.g. following `04_eval_patches_faceforensics_F2F.sh`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mult = 1\n",
    "\n",
    "# # Train on Face2Face\n",
    "# outdir = 'plots/Face2Face'\n",
    "# # prefix = '../results/gp5-faceforensics-f2f_baseline_resnet18_layer1/test/epoch_bestval/DF/'; mult=1; show_legend=False\n",
    "# # prefix = '../results/gp5-faceforensics-f2f_seed0_xception_block1_constant_p50/test/epoch_bestval/NT/'; mult=2; show_legend=False\n",
    "# prefix = '../results/gp5-faceforensics-f2f_baseline_resnet18_layer1/test/epoch_bestval/F2F/'; mult=1; show_legend=True\n",
    "\n",
    "# plot_patches(os.path.join(prefix,'patches_top10000/clusters'), mult=mult, outdir=outdir, show_legend=show_legend)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "forgery",
   "language": "python",
   "name": "forgery"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}